import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scipy
from IPython.core.pylabtools import figsize
from joblib import Parallel, delayed
from itertools import product


from PolicyLearning.Utility import LoadData2, GetRiskGain,  CD_234_shared2, GetRiskGain2, TPSort, PL_4
from PolicyLearning.Utility import PL_4

def PDP_level3(X_input, table3_array, n_levels=5):
    """Compute the partial-dependence (PD) matrix of a lookup-table policy.

    For every input column ``i`` and every level ``j`` in ``range(n_levels)``,
    column ``i`` is overwritten with ``j`` for all rows while the other columns
    keep their observed values. Each row is then mapped to a flat index into
    *table3_array* using a mixed-radix encoding with base ``n_levels`` (first
    column most significant, i.e. ``x0*25 + x1*5 + x2`` for three inputs and
    five levels). The looked-up outputs are averaged over rows.

    Parameters
    ----------
    X_input : array-like of int, shape (n_obs, n_inputs)
        Zero-based input scores. It is copied and never modified.
    table3_array : 1-D array of length ``n_levels ** n_inputs``
        Flattened lookup table of output scores.
    n_levels : int, default 5
        Number of levels each input can take.

    Returns
    -------
    numpy.ndarray of float, shape (n_inputs, n_levels)
        Mean looked-up output per (input, level), plus 1. The +1 presumably
        maps zero-based scores back to the 1-based scale, since the tables are
        loaded with ``- 1`` in this file.
    """
    temp = np.array(X_input, copy=True)
    n_inputs = temp.shape[1]
    # Mixed-radix place values, e.g. [25, 5, 1] for 3 inputs with 5 levels.
    weights = n_levels ** np.arange(n_inputs - 1, -1, -1)
    table = np.asarray(table3_array)
    result = np.empty(shape=(n_inputs, n_levels))
    for i in range(n_inputs):
        observed_column = temp[:, i].copy()
        for j in range(n_levels):
            temp[:, i] = j
            result[i, j] = table[temp @ weights].mean()
        # Restore the observed values before varying the next input.
        temp[:, i] = observed_column
    return result + 1


# Make float64 the default floating dtype for new tensors.
# torch.set_default_tensor_type is deprecated (PyTorch >= 2.1). On CPU,
# set_default_dtype(torch.float64) has the same effect as 'torch.DoubleTensor'.
torch.set_default_dtype(torch.float64)

# Baseline aggregation tables, shifted to zero-based scores.
table2_array = np.load('Database/table2_array.npy') - 1
table3_array = np.load('Database/table3_array.npy') - 1
table2_array_bin = np.zeros((25, 5), dtype=int)
table3_array_bin = np.zeros((125, 5), dtype=int)
# Same tables viewed as 5x5 and 5x5x5 arrays; the reshape also checks that
# table2 has 25 cells and table3 has 125.
table2_cur = np.array(table2_array).reshape((5, 5))
table3_cur = np.array(table3_array).reshape((5, 5, 5))
# One-hot encodings: row k has a 1 in the column of cell k's output score.
table2_array_bin[np.arange(25), table2_array] = 1
table3_array_bin[np.arange(125), table3_array] = 1


DataName = 'Data1969_09'
OutcomeCol1 = 'sec_c'
OutcomeCol2 = 'econ_c'
OutcomeCol3 = 'soccap_c'
# LoadData
dt = LoadData2(DataName)
# 'my4' column shifted to zero-based scores.
Safety_Score = dt.loc[:, 'my4'].values - 1
temp1 = np.array([0, 1, 2, 3, 4], dtype=int)
# Flat indices 25*i + 5*j + k of all cells of a 5x5x5 table (equals np.arange(125)).
temp4 = np.array([25*i + 5*j + k for i, j, k in product(range(5), range(5), range(5))])

# MCMC / GP settings. They are also encoded in the result file names below.
R = 4000          # draws per stored result (the stored tensor is 5 x R x N)
num_rep = 2000    # number of trailing draws kept
Matern_length_scale = 1
Matern_nu = 1.5
sig_latent = 2
sig_noise = 1
is_rescale = ''


def _mcmc_result_filename(outcome_col):
    """Return the MCMC result file name for *outcome_col* under the settings above."""
    parts = [is_rescale + DataName, outcome_col, 'TotalResult', str(R),
             str(int(100 * Matern_length_scale)), str(int(100 * Matern_nu)),
             str(int(100 * sig_latent)), str(int(100 * sig_noise))]
    return '_'.join(parts) + '.pt'


def _load_mcmc_draws(filename):
    """Load a stored MCMC tensor on CPU and keep only its last ``num_rep`` draws.

    The slice start is derived from R (it was previously hard-coded as 4000),
    so it stays correct when R is changed together with the file name.
    """
    draws = torch.load("Database/MCMCSample/" + filename, map_location=torch.device('cpu'))
    return draws[:, (R - num_rep):, :]  # 5 x num_rep x N after slicing


myfilename1 = _mcmc_result_filename(OutcomeCol1)
myfilename2 = _mcmc_result_filename(OutcomeCol2)
TotalResult1 = _load_mcmc_draws(myfilename1)
TotalResult2 = _load_mcmc_draws(myfilename2)
TotalResult = TotalResult1
Risk_table, Gain_table = GetRiskGain2(TotalResult)



_INPUT_NAMES = ['V1', 'V2', 'V3']


def _zero_based_inputs(cols):
    """Select *cols* from ``dt``, shift them to zero-based scores, and name them V1, V2[, V3]."""
    return pd.DataFrame(dt.loc[:, cols].values - 1, columns=_INPUT_NAMES[:len(cols)])


# Input frames of each aggregation table (source columns -> V1..V3, zero-based).
X4_input = _zero_based_inputs(['3a', '3b', '3c'])
X3a_input = _zero_based_inputs(['2a', '2b', 'mod1c_num'])
X3b_input = _zero_based_inputs(['2d', '2c', 'mod1m_num'])
X3c_input = _zero_based_inputs(['2e', '2f'])
X2a_input = _zero_based_inputs(['mod1a_num', 'mod1b_num'])
X2b_input = _zero_based_inputs(['mod1e_num', 'mod1d_num', '1f1g'])
X2c_input = _zero_based_inputs(['mod1h_num', 'mod1i_num'])
X2d_input = _zero_based_inputs(['mod1j_num', '1l1k', 'mod1g_num'])
X2e_input = _zero_based_inputs(['mod1o_num', 'mod1n_num', 'mod1p_num'])
X2f_input = _zero_based_inputs(['mod1r_num', 'mod1q_num', 'mod1s_num'])


def _count_mismatches(inputs, table, target_col):
    """Count rows whose lookup of *inputs* in *table* differs from zero-based dt[target_col]."""
    if inputs.shape[1] == 3:
        cell = inputs.iloc[:, 0] * 25 + inputs.iloc[:, 1] * 5 + inputs.iloc[:, 2]
    else:
        cell = inputs.iloc[:, 0] * 5 + inputs.iloc[:, 1]
    return (table[cell] != dt[target_col] - 1).sum()


# Sanity check: each recorded aggregate score must equal the baseline table
# applied to its recorded inputs (3-input tables use table3, 2-input tables use table2).
_sanity_checks = [
    (X2b_input, table3_array, '2b'),
    (X2d_input, table3_array, '2d'),
    (X2e_input, table3_array, '2e'),
    (X2f_input, table3_array, '2f'),
    (X3a_input, table3_array, '3a'),
    (X3b_input, table3_array, '3b'),
    (X4_input, table3_array, 'my4'),
    (X2a_input, table2_array, '2a'),
    (X2c_input, table2_array, '2c'),
    (X3c_input, table2_array, '3c'),
]
if sum(_count_mismatches(*check) for check in _sanity_checks) == 0:
    print('Passed sanity check')
else:
    print("Didn't pass the check!!!!!!!!!!!!!!!!!!!!!!!!")

def Simu_onelyer(epsilon, penalty=1e-5):
    """Learn a new 3-way table with PL_4 for one epsilon and summarise it.

    Reads the module-level X4_input, Risk_table, Gain_table, R and table3_array.

    Returns a 6-tuple:
      0. the learned table reshaped to (5, 5, 5);
      1. ``ObjVal`` of the model object returned by PL_4;
      2. the learned table's PDP_level3 matrix;
      3. the fraction of table cells that differ from table3_array;
      4. the fraction of observations whose looked-up score differs from the baseline;
      5. the per-input standard deviation of the PD matrix across levels.
    """
    model, raw_table = PL_4(X4_input, Risk_table, Gain_table, R, epsilon, penalty=penalty)
    learned = raw_table.X.astype(int)
    # Flat table cell of every observation, computed once for both lookups.
    row_cells = X4_input['V1'] * 25 + X4_input['V2'] * 5 + X4_input['V3']
    changed_cells = (learned - table3_array) != 0
    changed_rows = (learned[row_cells] - table3_array[row_cells]) != 0
    pdp_learned = PDP_level3(X4_input.values, learned)
    return (
        learned.reshape((5, 5, 5)),
        model.ObjVal,
        pdp_learned,
        changed_cells.mean(),
        changed_rows.mean(),
        pdp_learned.std(axis=1),
    )

# Sweep the chance-constraint level: epsilon = exp(-0.1 * n), n = 0..95.
Epsilon = [np.exp(-0.1*n) for n in range(96)]
myresult = Parallel(n_jobs=12)(delayed(Simu_onelyer)(i) for i in Epsilon)
PAUI = np.array([i[1] for i in myresult])
PDPi = np.array([i[5] for i in myresult])

# Slice k of TotalResultD is sigmoid(TotalResult[0] + ... + TotalResult[k]), k = 0..4.
# It is computed once, directly in torch, instead of twice via a hand-written
# stack of partial sums fed to scipy.special.expit.
TotalResultD = torch.sigmoid(torch.cumsum(TotalResult[:5], dim=0))
n_obs = TotalResultD.shape[2]   # was hard-coded as 1954
obs_idx = np.arange(n_obs)
rep_idx = np.arange(num_rep)
# PAU[k] = posterior mean of TotalResultD at observation k's 'my4' level.
# Pairwise indexing avoids building an n_obs x n_obs matrix just to take its diagonal.
PAU = TotalResultD.mean(dim=1)[dt.my4.values - 1, obs_idx]
Changed_Element = [i[3] for i in myresult]
Changed_decision = [i[4] for i in myresult]
PDP_matrix_old = PDP_level3(X4_input.values, table3_array)

cell_idx = (X4_input['V1'] * 25 + X4_input['V2'] * 5 + X4_input['V3']).values
score_old = table3_array[cell_idx]


def _draws_at_scores(scores):
    """Return TotalResultD[scores[k], r, k] for every draw r and observation k, shape (num_rep, n_obs)."""
    return TotalResultD[scores[None, :], rep_idx[:, None], obs_idx]


eu = _draws_at_scores(score_old)
# post[e, r]: mean over observations of (learned-table value - baseline value)
# for Epsilon[e] and posterior draw r. Averaging inside the loop avoids
# materialising a len(Epsilon) x num_rep x n_obs array (~3 GB of float64 for
# 96 x 2000 x 1954).
post = np.array([
    (_draws_at_scores(res[0].reshape(-1)[cell_idx]) - eu).mean(dim=1).numpy()
    for res in myresult
])

from scipy.stats import norm

def plot11_CI():
    """Plot posterior utility of the learned policies against epsilon, with a 95% band.

    Each curve is PAU.mean() (the baseline) plus a summary over posterior draws
    (axis 1) of the module-level ``post``: the 2.5% quantile, the mean, and the
    97.5% quantile. The baseline itself is the dashed black line.
    Saves PolicyLearning/fig1/f611-11.pdf.
    """
    base = PAU.mean().numpy()
    q_low, q_high = np.quantile(post, [.025, .975], axis=1)
    fs = 14
    fig, ax = plt.subplots()
    ax.set_ylabel('Posterior expected utility', fontsize=fs)
    curves = [
        (q_low, 'green', '2.5% posterior quantile utility'),
        (post.mean(axis=1), 'b', 'posterior expected utility'),
        (q_high, 'red', '97.5% posterior quantile utility'),
    ]
    for offset, colour, text in curves:
        ax.plot(Epsilon, base + offset, color=colour, linewidth=2, label=text)
    ax.set_xlabel('Epsilon in the chance constraint', fontsize=fs)
    ax.axhline(y=base, color='black', linestyle='--', linewidth=2, label='baseline policy')
    ax.set_xscale('log')
    fig.legend(loc='center', bbox_to_anchor=(0.39, 0.84), fontsize=11)
    fig.tight_layout()
    fig.show()
    fig.savefig('PolicyLearning/fig1/f611-11.pdf')


def plot12():
    """Plot how many cells of the learned 3-way table differ from the baseline.

    Draws the module-level Changed_Element (fraction of changed table cells,
    one value per entry of Epsilon) on a log-scaled epsilon axis, with a dashed
    zero line for the baseline table. Saves PolicyLearning/fig1/f611-12.pdf.
    """
    fs = 14
    fig, ax = plt.subplots()
    ax.set_ylabel('Proportion of modified elements', fontsize=fs)
    ax.plot(Epsilon, Changed_Element, color='g',
            label='Proportion of modified elements in the learned 3-way table', linewidth=2)
    ax.set_xlabel('Epsilon in the chance constraint', fontsize=fs)
    # The baseline table modifies nothing, hence the dashed line at zero.
    # No legend is drawn, so the labels in this figure are not displayed.
    ax.axhline(y=0, color='g', linestyle='--',
               label='posterior expected utility (baseline policy)', linewidth=2)
    annotations = [(0.05, 0.01, 'baseline policy'), (0.06, 0.1, 'learned policy')]
    for x_pos, y_pos, caption in annotations:
        ax.text(x_pos, y_pos, caption, fontsize=fs)
    ax.set_xscale('log')
    fig.tight_layout()
    fig.show()
    fig.savefig('PolicyLearning/fig1/f611-12.pdf')



# There is no function named plot11; the utility-vs-epsilon figure is plot11_CI.
plot11_CI()
plot12()

def plot2():
    """Stem plot of partial dependence: baseline vs learned policy at epsilon = 0.1.

    Re-learns the 3-way table with PL_4 (epsilon 0.1, penalty 1e-5) from the
    module-level X4_input, Risk_table, Gain_table and R. Its PDP_level3 matrix
    is drawn as solid stems with circle markers, next to the module-level
    PDP_matrix_old, drawn as dashed translucent stems with square markers.
    LaTeX text rendering is enabled through plt.rc and stays in effect for
    figures created afterwards. Saves PolicyLearning/fig1/f611-2.pdf.
    """
    plt.rc('text', usetex=True)
    plt.rc('text.latex', preamble=r'\usepackage{amsmath,amssymb,wasysym}')
    _, rawtable = PL_4(X4_input, Risk_table, Gain_table, R, 0.1, penalty=1e-5)
    newtable_binary = rawtable.X.astype(int)
    PDP_matrix_new = PDP_level3(X4_input.values, newtable_binary)
    names = ['military', 'political', 'socioeconomic']
    colors = ['b', 'g', 'r']
    fontsize = 30
    linewidth = 5
    markersize = 200
    fig, ax1 = plt.subplots()
    fig.set_size_inches(16, 9.5)
    # Raise the axes' bottom edge by 10% of its height (and shrink it to 90%)
    # to leave room for the legend anchored near the bottom of the figure.
    box = ax1.get_position()
    ax1.set_position([box.x0, box.y0 + box.height * 0.1, box.width, box.height * 0.9])
    width = 0.27   # horizontal spacing between the three inputs at one level
    shift = 0.1    # extra horizontal offset of the learned-policy stems/markers
    Xloc = np.array([1, 2, 3, 4, 5])
    ax1.set_ylabel('Marginal expectation of the output security score', fontsize=fontsize / 1.2)
    # Baseline policy: dashed, translucent stems.
    for i in range(3):
        for j in range(5):
            ax1.plot([Xloc[j] + (i - 1) * width] * 2, [0, PDP_matrix_old[i, j]], alpha=0.5,
                     linestyle='--', c=colors[i], linewidth=linewidth)
    # Learned policy: solid stems. Only the last stem of each input carries the
    # legend label, so each input appears in the legend exactly once.
    for i in range(3):
        for j in range(5):
            label_kw = ({'label': names[i] + ' score (baseline $\\square$ learned $\\ocircle$)'}
                        if j == 4 else {})
            ax1.plot([Xloc[j] + (i - 1) * width + shift] * 2, [0, PDP_matrix_new[i, j]],
                     c=colors[i], linewidth=linewidth, **label_kw)
    # Markers at the stem tops: squares for the baseline policy.
    for i in range(3):
        ax1.scatter(Xloc + (i - 1) * width, PDP_matrix_old[i, :],
                    c=colors[i], s=markersize, marker='s', alpha=0.5)
    ax1.set_xlabel('Input level-3 scores', fontsize=fontsize)
    ax1.tick_params(axis='both', which='major', labelsize=fontsize)
    ax1.set_ylim((0, 5))
    # Circles for the learned policy.
    for i in range(3):
        ax1.scatter(Xloc + (i - 1) * width + shift, PDP_matrix_new[i, :],
                    c=colors[i], s=markersize, marker='o')
    # Raw string: '\e' in a normal string literal is an invalid escape sequence.
    plt.title(r'Partial Dependence Plot of the baseline policy and learned policy ($\epsilon=0.1$)',
              fontsize=fontsize)
    fig.legend(loc='upper left', bbox_to_anchor=(0.02, 0.0, 0.1, 0.1), ncol=3,
               fontsize=fontsize / 1.6)
    fig.show()
    fig.savefig('PolicyLearning/fig1/f611-2.pdf')

def plot22():
    """Bar chart of the relative change in partial dependence, learned vs baseline (epsilon = 0.1).

    Re-learns the 3-way table with PL_4 (epsilon 0.1, penalty 1e-5) from the
    module-level X4_input, Risk_table, Gain_table and R. It then plots
    PDP_new / PDP_matrix_old - 1 (+0.001) for each input and level. LaTeX text
    rendering is enabled through plt.rc and stays in effect for later figures.
    Saves PolicyLearning/fig1/f611-2_1.pdf.
    """
    plt.rc('text', usetex=True)
    plt.rc('text.latex', preamble=r'\usepackage{amsmath,amssymb,wasysym}')
    _, rawtable = PL_4(X4_input, Risk_table, Gain_table, R, 0.1, penalty=1e-5)
    newtable_binary = rawtable.X.astype(int)
    PDP_matrix_new = PDP_level3(X4_input.values, newtable_binary)
    names = ['military', 'political', 'socioeconomic']
    colors = ['blue', 'green', 'red']
    fontsize = 30
    fig, ax1 = plt.subplots()
    fig.set_size_inches(16, 10)
    width = 0.3
    Xloc = np.array([1, 2, 3, 4, 5])
    ax1.set_ylabel('Relative change of the partial dependence function', fontsize=fontsize / 1.2)
    # NOTE(review): the +0.001 offset presumably keeps zero-change bars visible; confirm intent.
    relative_change = PDP_matrix_new / PDP_matrix_old - 1 + 0.001
    for i in range(3):
        ax1.bar(Xloc + (i - 1) * width, relative_change[i, :], label=names[i] + ' score',
                color=colors[i], width=width)

    ax1.set_xlabel('Input level-3 scores', fontsize=fontsize)
    ax1.tick_params(axis='both', which='major', labelsize=fontsize)
    ax1.set_ylim((-.1, 0.45))
    fig.legend(loc='upper left', bbox_to_anchor=(0.72, 0.84, 0.1, 0.1), ncol=1,
               fontsize=fontsize / 1.2)
    fig.tight_layout()
    fig.show()
    fig.savefig('PolicyLearning/fig1/f611-2_1.pdf')


plot22()

# Re-select the first outcome's posterior draws (loaded from myfilename1) and
# rebuild the risk/gain tables that Simu_onelyer reads, before plot31.
TotalResult = TotalResult1
Risk_table, Gain_table = GetRiskGain2(TotalResult1)
def plot31():
    """Scaled PD importance of each input against epsilon (penalty 2.5e-5).

    Re-learns the table with Simu_onelyer for 30 epsilon values
    exp(-2.3 - 0.1 * i), i = 0..29, flipped into ascending order. It uses the
    module-level Risk_table/Gain_table current at call time. Each learned
    policy's per-input PD importance (element 5 of Simu_onelyer's result) is
    normalised to sum to one and drawn solid. The same normalisation of the
    baseline PDP_matrix_old is drawn dashed.
    Saves PolicyLearning/fig1/f615_sens2.pdf.
    """
    Epsilon = np.flip(np.array([np.exp(-2.3 - i * 0.1) for i in range(30)]))
    # delayed() forwards the penalty directly; no wrapper lambda is needed.
    myresult = Parallel(n_jobs=12)(delayed(Simu_onelyer)(eps, 2.5e-5) for eps in Epsilon)
    PDPi = np.array([res[5] for res in myresult])
    PDPi_scale = PDPi / PDPi.sum(axis=1).reshape((-1, 1))
    PDPi_old = PDP_matrix_old.std(axis=1) / PDP_matrix_old.std(axis=1).sum()
    fontsize = 25
    linewidth = 4
    colors = ['b', 'r', 'g']  # input columns 0..2: Military, Political, Socioeconomic
    fig, ax = plt.subplots()
    fig.set_size_inches(12, 8)
    for k, c in enumerate(colors):
        ax.plot(Epsilon, PDPi_scale[:, k], c=c, label='learned policy', linewidth=linewidth)
    for k, c in enumerate(colors):
        ax.plot(Epsilon, np.full(len(Epsilon), PDPi_old[k]), c=c, linestyle='--',
                linewidth=linewidth, label='baseline policy')
    # Inline text labels stand in for a legend.
    ax.text(0.005, 0.45, 'Military', c='b', fontsize=fontsize)
    ax.text(0.005, 0.34, 'Political', c='r', fontsize=fontsize)
    ax.text(0.005, 0.24, 'Socioeconomic', c='g', fontsize=fontsize)
    ax.set_xlabel('Epsilon in the Bayesian safe policy learning', fontsize=fontsize)
    ax.set_ylabel('Scaled PD importance', fontsize=fontsize)
    ax.tick_params(axis='both', which='major', labelsize=fontsize)
    ax.set_xscale('log')
    fig.tight_layout()
    fig.show()
    fig.savefig('PolicyLearning/fig1/f615_sens2.pdf')

plot31()

# Switch to the second outcome's posterior draws (myfilename2 / OutcomeCol2)
# and rebuild the risk/gain tables that Simu_onelyer reads, for plot32.
TotalResult = TotalResult2
Risk_table, Gain_table = GetRiskGain2(TotalResult)
def plot32():
    """Scaled PD importance of each input against epsilon (default penalty).

    Re-learns the table with Simu_onelyer for 30 epsilon values
    exp(-2.3 - 0.1 * i), i = 0..29, flipped into ascending order. It uses the
    module-level Risk_table/Gain_table current at call time. Each learned
    policy's per-input PD importance (element 5 of Simu_onelyer's result) is
    normalised to sum to one and drawn solid. The same normalisation of the
    baseline PDP_matrix_old is drawn dashed.
    Saves PolicyLearning/fig1/f616.pdf.
    """
    Epsilon = np.flip(np.array([np.exp(-2.3 - i * 0.1) for i in range(30)]))
    myresult = Parallel(n_jobs=12)(delayed(Simu_onelyer)(eps) for eps in Epsilon)
    PDPi = np.array([res[5] for res in myresult])
    PDPi_scale = PDPi / PDPi.sum(axis=1).reshape((-1, 1))
    PDPi_old = PDP_matrix_old.std(axis=1) / PDP_matrix_old.std(axis=1).sum()
    fontsize = 25
    linewidth = 4
    colors = ['b', 'r', 'g']  # input columns 0..2: Military, Political, Socioeconomic
    fig, ax = plt.subplots()
    fig.set_size_inches(12, 8)
    for k, c in enumerate(colors):
        ax.plot(Epsilon, PDPi_scale[:, k], c=c, label='learned policy', linewidth=linewidth)
    for k, c in enumerate(colors):
        ax.plot(Epsilon, np.full(len(Epsilon), PDPi_old[k]), c=c, linestyle='--',
                linewidth=linewidth, label='baseline policy')
    # Inline text labels stand in for a legend.
    ax.text(0.005, 0.48, 'Military', c='b', fontsize=fontsize)
    ax.text(0.005, 0.34, 'Political', c='r', fontsize=fontsize)
    ax.text(0.005, 0.20, 'Socioeconomic', c='g', fontsize=fontsize)
    ax.set_xlabel('Epsilon in the Bayesian safe policy learning', fontsize=fontsize)
    ax.set_ylabel('Scaled PD importance', fontsize=fontsize)
    ax.tick_params(axis='both', which='major', labelsize=fontsize)
    ax.set_xscale('log')
    fig.tight_layout()
    fig.show()
    fig.savefig('PolicyLearning/fig1/f616.pdf')


plot32()


# Third outcome (OutcomeCol3). TotalResult3 is not defined anywhere else in this
# file, so `TotalResult = TotalResult3` raised NameError. Load it the same way
# TotalResult1/TotalResult2 are loaded.
# NOTE(review): assumes the matching .pt file exists in Database/MCMCSample.
myfilename3 = '_'.join([is_rescale + DataName, OutcomeCol3, 'TotalResult', str(R),
                        str(int(100 * Matern_length_scale)), str(int(100 * Matern_nu)),
                        str(int(100 * sig_latent)), str(int(100 * sig_noise))]) + '.pt'
TotalResult3 = torch.load("Database/MCMCSample/" + myfilename3,
                          map_location=torch.device('cpu'))[:, (R - num_rep):, :]  # 5*num_rep*N
TotalResult = TotalResult3
Risk_table, Gain_table = GetRiskGain2(TotalResult)

def plot33():
    """Scaled PD importance of each input against epsilon (default penalty).

    Re-learns the table with Simu_onelyer for 30 epsilon values
    exp(-2.3 - 0.1 * i), i = 0..29, flipped into ascending order. It uses the
    module-level Risk_table/Gain_table current at call time. Each learned
    policy's per-input PD importance (element 5 of Simu_onelyer's result) is
    normalised to sum to one and drawn solid. The same normalisation of the
    baseline PDP_matrix_old is drawn dashed.
    Saves PolicyLearning/fig1/f617.pdf.
    """
    Epsilon = np.flip(np.array([np.exp(-2.3 - i * 0.1) for i in range(30)]))
    myresult = Parallel(n_jobs=12)(delayed(Simu_onelyer)(eps) for eps in Epsilon)
    PDPi = np.array([res[5] for res in myresult])
    PDPi_scale = PDPi / PDPi.sum(axis=1).reshape((-1, 1))
    PDPi_old = PDP_matrix_old.std(axis=1) / PDP_matrix_old.std(axis=1).sum()
    fontsize = 25
    linewidth = 4
    colors = ['b', 'r', 'g']  # input columns 0..2: Military, Political, Socioeconomic
    fig, ax = plt.subplots()
    fig.set_size_inches(12, 8)
    for k, c in enumerate(colors):
        ax.plot(Epsilon, PDPi_scale[:, k], c=c, label='learned policy', linewidth=linewidth)
    for k, c in enumerate(colors):
        ax.plot(Epsilon, np.full(len(Epsilon), PDPi_old[k]), c=c, linestyle='--',
                linewidth=linewidth, label='baseline policy')
    # Inline text labels stand in for a legend.
    ax.text(0.005, 0.48, 'Military', c='b', fontsize=fontsize)
    ax.text(0.005, 0.31, 'Political', c='r', fontsize=fontsize)
    ax.text(0.005, 0.187, 'Socioeconomic', c='g', fontsize=fontsize)
    ax.set_xlabel('Epsilon in the Bayesian safe policy learning', fontsize=fontsize)
    ax.set_ylabel('Scaled PD importance', fontsize=fontsize)
    ax.tick_params(axis='both', which='major', labelsize=fontsize)
    ax.set_xscale('log')
    fig.tight_layout()
    fig.show()
    fig.savefig('PolicyLearning/fig1/f617.pdf')

# Sensitivity figure for the TotalResult / Risk_table / Gain_table selected
# immediately above; plot33 saves it as PolicyLearning/fig1/f617.pdf.
plot33()